Selected Triton Puzzles for Beginners

Original:

srush/Triton-Puzzles: Puzzles for learning Triton

Lite version:

SiriusNEO/Triton-Puzzles-Lite: Puzzles for learning Triton, play it with minimal environment configuration!

Reference solutions:

alexzhang13/Triton-Puzzles-Solutions: Personal solutions to the Triton Puzzles

Puzzle 2: Adding a constant, with blocks


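Here the block size B0 is smaller than N0, the length of x, so a single block no longer covers the whole vector. Take care not to index out of bounds: both the load and the store need a mask.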

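@triton.jit
def add_mask2_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    pid_x = tl.program_id(0)
    # Offsets of the B0 consecutive elements handled by this program.
    off_x = tl.arange(0, B0) + pid_x * B0
    # The last block may run past N0, so switch off the extra lanes.
    mask = off_x < N0
    x = tl.load(x_ptr + off_x, mask, 0)
    # Same result as add2_spec(x); written inline because a plain Python
    # function cannot be called from a compiled kernel.
    z = x + 10.0
    tl.store(z_ptr + off_x, z, mask)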
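
To see what the mask does in the kernel above, the block arithmetic can be emulated in plain Python. This is only a sketch, not Triton code: `cdiv`, `emulate_add_mask2`, the `const` parameter and the sizes 200 and 32 are hypothetical names and values chosen for illustration.

```python
def cdiv(n, block):
    """Ceiling division: number of blocks needed to cover n elements."""
    return (n + block - 1) // block


def emulate_add_mask2(x, B0, const=10.0):
    """Emulate a 1D masked block kernel in plain Python (illustration only).

    `const` is an assumed stand-in for the constant the kernel adds.
    Returns the output list and the number of lanes the mask switched off.
    """
    N0 = len(x)
    z = [0.0] * N0
    masked_lanes = 0
    for pid in range(cdiv(N0, B0)):            # one iteration per program id
        offsets = [pid * B0 + i for i in range(B0)]
        mask = [off < N0 for off in offsets]
        for off, ok in zip(offsets, mask):
            if ok:                              # masked load + masked store
                z[off] = x[off] + const
            else:                               # lane would be out of bounds
                masked_lanes += 1
    return z, masked_lanes


x = [float(i) for i in range(200)]
z, masked_lanes = emulate_add_mask2(x, B0=32)
print(cdiv(200, 32), masked_lanes)
```

With these example sizes, 7 programs cover 224 lanes, so the last program has 24 lanes that must be masked off on both the load and the store.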

Puzzle 4: Adding two arrays, with blocks


@triton.jit
def add_vec_block_kernel(
    x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr
):
    block_id_x = tl.program_id(0)
    block_id_y = tl.program_id(1)

    off_x = block_id_x * B0 + tl.arange(0, B0)
    off_y = block_id_y * B1 + tl.arange(0, B1)

    mask_x = off_x < N0
    mask_y = off_y < N1

    x = tl.load(x_ptr + off_x, mask_x, 0)
    y = tl.load(y_ptr + off_y, mask_y, 0)

    # x is a row, y is a column
    z = x[None, :] + y[:, None]
    # z is row-major with N0 elements per row
    off_z = off_x[None, :] + off_y[:, None] * N0
    mask_z = mask_x[None, :] & mask_y[:, None]
    # Masked store only: an unmasked store here would write out of bounds.
    tl.store(z_ptr + off_z, z, mask_z)

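Always mask the store! Without mask_z, the lanes of an edge tile that fall outside the N1 x N0 output get written too.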
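
What goes wrong without the mask can be checked with a few lines of plain Python. The sketch below is not Triton code; `tile_offsets` and the small sizes are hypothetical, chosen only to keep the numbers readable. It lists the flat offsets that one edge tile would write, following `off_z = off_x[None,:] + off_y[:,None]*N0`, and picks out the ones the mask rejects.

```python
def tile_offsets(pid_x, pid_y, N0, N1, B0, B1):
    """Return (flat_offset, in_bounds) for every lane of one 2D tile.

    Mirrors off_z = off_x[None, :] + off_y[:, None] * N0 for a row-major
    output with N1 rows and N0 columns. Illustration only.
    """
    lanes = []
    for j in range(B1):                  # rows of the tile (y direction)
        off_y = pid_y * B1 + j
        for i in range(B0):              # columns of the tile (x direction)
            off_x = pid_x * B0 + i
            in_bounds = off_x < N0 and off_y < N1
            lanes.append((off_x + off_y * N0, in_bounds))
    return lanes


# Hypothetical sizes: the output has 3 rows and 5 columns, tiles are 2 by 4.
N0, N1, B0, B1 = 5, 3, 4, 2
lanes = tile_offsets(pid_x=1, pid_y=0, N0=N0, N1=N1, B0=B0, B1=B1)

# Offsets that an unmasked store would write but the mask rejects.
stray = [off for off, ok in lanes if not ok]
# Stray offsets that still land inside the buffer and overwrite other rows.
corrupting = [off for off in stray if off < N0 * N1]
print(stray, corrupting)
```

In this example the rejected offsets are 5, 6, 7 and 10, 11, 12. All of them lie inside the 15-element buffer, in the first columns of the following rows, so an unmasked store would not necessarily crash but would overwrite values that belong to another program.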

Puzzle 5: ReLU


@triton.jit
def mul_relu_block_kernel(
    x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr
):
    block_id_x = tl.program_id(0)
    block_id_y = tl.program_id(1)

    off_x = block_id_x * B0 + tl.arange(0, B0)
    off_y = block_id_y * B1 + tl.arange(0, B1)

    mask_x = off_x < N0
    mask_y = off_y < N1

    x = tl.load(x_ptr + off_x, mask_x, 0)
    y = tl.load(y_ptr + off_y, mask_y, 0)

    # Outer product, then ReLU: keep non-negative values, zero the rest.
    z = x[None, :] * y[:, None]
    z = tl.where(z >= 0, z, 0)

    off_z = off_x[None, :] + off_y[:, None] * N0
    mask_z = mask_x[None, :] & mask_y[:, None]

    tl.store(z_ptr + off_z, z, mask_z)
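
A quick way to convince yourself that the tiling and the mask together cover every output element exactly once is to emulate the kernel in plain Python and compare it with a direct computation. This is a sketch under assumptions, not Triton code: `mul_relu_reference`, `mul_relu_tiled`, `cdiv` and the input values are hypothetical and exist only for this illustration.

```python
def cdiv(n, block):
    """Ceiling division: number of blocks needed to cover n elements."""
    return (n + block - 1) // block


def mul_relu_reference(x, y):
    """z[j][i] = max(x[i] * y[j], 0): ReLU of the outer product."""
    return [[max(xi * yj, 0.0) for xi in x] for yj in y]


def mul_relu_tiled(x, y, B0, B1):
    """Plain-Python emulation of the 2D tiled kernel (illustration only)."""
    N0, N1 = len(x), len(y)
    z = [0.0] * (N0 * N1)                       # flat row-major output buffer
    for pid_y in range(cdiv(N1, B1)):           # program ids along axis 1
        for pid_x in range(cdiv(N0, B0)):       # program ids along axis 0
            for j in range(B1):
                off_y = pid_y * B1 + j
                for i in range(B0):
                    off_x = pid_x * B0 + i
                    if off_x < N0 and off_y < N1:       # plays the role of mask_z
                        prod = x[off_x] * y[off_y]
                        z[off_x + off_y * N0] = prod if prod >= 0 else 0.0
    # Reshape the flat buffer into N1 rows of N0 elements.
    return [z[r * N0:(r + 1) * N0] for r in range(N1)]


x = [-2.0, -1.0, 1.0, 2.0, 3.0]
y = [-1.0, 3.0, 0.5]
tiled = mul_relu_tiled(x, y, B0=4, B1=2)
print(tiled == mul_relu_reference(x, y))
```

The block sizes 4 and 2 deliberately do not divide the input lengths 5 and 3, so the edge tiles exercise the mask.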
